/// # Examples
/// ```rust
/// use hicollections::{AvlTree, AvlTreeNode, avltree};
/// use std::mem::MaybeUninit;
/// use std::cmp::{PartialEq, PartialOrd, Ordering};
/// use std::ptr;
/// use std::borrow::Borrow;
///
/// struct Foo {
///     val: i32,
///     node: AvlTreeNode,
/// }
///
/// impl Borrow<i32> for Foo {
///     fn borrow(&self) -> &i32 {
///         &self.val
///     }
/// }
///
/// let mut tree = avltree!(Foo, node);
///
/// let mut foos: [Foo; 100] = unsafe { MaybeUninit::uninit().assume_init() };
///
/// for (n, foo) in foos.iter_mut().enumerate() {
///     foo.val = n as i32;
///     unsafe { ptr::write(&mut foo.node, AvlTreeNode::new()); }
///     assert!(unsafe { tree.insert_borrow::<i32>(foo, false) }.is_some());
/// };
///
/// assert!(!tree.empty());
/// assert_eq!(tree.first().unwrap().val, 0);
/// assert_eq!(tree.last().unwrap().val, 99);
///
/// for (n, foo) in tree.iter().enumerate() {
///     assert_eq!(foo.val, n as i32);
/// };
///
/// for (n, foo) in tree.iter().rev().enumerate() {
///     assert_eq!(99 - foo.val, n as i32);
/// };
///
/// assert_eq!(tree.iter().count(), foos.iter().rev().count());
///
/// for foo in &foos {
///     let found = tree.find(&foo.val);
///     assert!(found.is_some());
///     assert!(ptr::eq(found.unwrap(), foo));
/// };
///
/// for foo in &foos {
///     let removed = tree.remove(&foo.val);
///     assert!(removed.is_some());
///     assert!(ptr::eq(removed.unwrap(), foo));
/// };
///
/// assert!(tree.empty());
/// assert!(tree.first().is_none());
/// assert!(tree.last().is_none());
///
/// ```
///

use crate::Link;
use core::borrow::Borrow;
use core::cell::UnsafeCell;
use core::cmp::Ordering;
use core::debug_assert;
use core::marker::{PhantomData, PhantomPinned};
use core::ops::{
    Bound::{self, Excluded, Included, Unbounded},
    RangeBounds,
};
use core::ops::{Deref, DerefMut};
use core::ptr::{self, NonNull};

/// Convenience macro for constructing an empty [`AvlTree`].
///
/// - `avltree!()` builds an `AvlTree<AvlTreeNode>` whose elements are bare nodes.
/// - `avltree!(Type, field)` (or a nested field path such as `field.inner`) builds an
///   `AvlTree<Type>` that links elements through the `AvlTreeNode` stored at that field.
///
/// ```rust
/// use hicollections::{avltree, AvlTree, AvlTreeNode};
/// let tree: AvlTree<AvlTreeNode> = avltree!();
/// struct Foo {
///     val: i32,
///     node: AvlTreeNode,
/// }
/// let tree: AvlTree<Foo> = avltree!(Foo, node);
/// ```
#[macro_export]
macro_rules! avltree {
    // At least one field name is required: with `.+`, `avltree!(Foo,)` is rejected by the
    // matcher instead of expanding to a malformed `addr_of!((*node).)`.
    ($type: ty, $($mem: ident).+) => {
        $crate::AvlTree::<$type>::new(|node| unsafe {
            ::core::ptr::addr_of!((*node).$($mem).+)
        })
    };
    // Fully qualified so the caller does not need `AvlTreeNode` imported.
    () => {
        $crate::AvlTree::<$crate::AvlTreeNode>::new(|node| node)
    };
}

/// An intrusive AVL tree. It follows the same design as `List`: elements embed an
/// [`AvlTreeNode`] member, and the tree links those members together. `AvlTree` supports `Send`.
/// Ordering is defined through `Borrow<Q>` (with `Q: Ord`), so the element type itself
/// does not have to implement `Ord`.
#[repr(C)]
pub struct AvlTree<T> {
    // Root node, or `None` when the tree is empty.
    root: Option<AvlTreeNodeRef>,
    // Byte offset of the `AvlTreeNode` member inside `T`
    // (added in `node_from`, subtracted in `obj_from`).
    offset: usize,
    // Ties the tree to `T` without owning any `T`. The raw pointer also opts out of the
    // automatic `Send`/`Sync` impls; they are provided manually below.
    mark: PhantomData<*const T>,
}

// Both impls require only `T: Sync`.
// NOTE(review): `iter_mut`/`range_mut` hand out `&mut T`; confirm whether `Send`
// should also require `T: Send`.
unsafe impl<T: Sync> Send for AvlTree<T> {}
unsafe impl<T: Sync> Sync for AvlTree<T> {}

/// Link member that a user data structure must contain in order to be inserted into
/// an [`AvlTree`].
#[repr(transparent)]
pub struct AvlTreeNode(UnsafeCell<Inner>);

// Link state of one node. The two low bits of `left` and `right` are tag bits:
// `left & 0x3` holds this node's position under its parent (`Pos`), and
// `right & 0x3` holds its balance factor (`Height`). The remaining bits are the child
// node's address, or 0 when there is no child on that side.
#[repr(C)]
struct Inner {
    // Parent node; `None` for the root and for unlinked nodes.
    parent: Option<AvlTreeNodeRef>,
    // Tagged left-child address (low bits: `Pos`).
    left: usize,
    // Tagged right-child address (low bits: `Height`).
    right: usize,
    mark: PhantomData<*const AvlTreeNode>,
    // Other nodes store this node's address, so it must not be moved while linked.
    pin: PhantomPinned,
}

/// Side of its parent that a node hangs on; stored in the tag bits of `Inner::left`.
#[derive(Clone, Copy)]
enum Pos {
    /// The node is its parent's left child.
    Left = 0,
    /// The node is its parent's right child.
    Right = 1,
}

/// Balance factor of a node; stored in the tag bits of `Inner::right`.
#[derive(Clone, Copy)]
enum Height {
    /// Both subtrees have the same height.
    Eq = 0,
    /// The left subtree is one level taller.
    Lh = 1,
    /// The right subtree is one level taller.
    Rh = 2,
}

/// Copyable, non-owning handle to an [`AvlTreeNode`].
#[derive(Clone, Copy)]
struct AvlTreeNodeRef {
    ptr: NonNull<AvlTreeNode>,
}

impl AvlTreeNodeRef {
    /// Raw pointer to the referenced node.
    fn as_ptr(&self) -> *mut AvlTreeNode {
        let AvlTreeNodeRef { ptr } = *self;
        ptr.as_ptr()
    }
}

impl From<&'_ AvlTreeNode> for AvlTreeNodeRef {
    fn from(src: &'_ AvlTreeNode) -> Self {
        Self {
            ptr: NonNull::from(src),
        }
    }
}

impl Deref for AvlTreeNodeRef {
    type Target = AvlTreeNode;

    fn deref(&self) -> &AvlTreeNode {
        // SAFETY: relies on the tree's insertion contract (`AvlTree::insert`): a linked
        // node stays alive and is not moved while handles to it are in use.
        unsafe { &*self.ptr.as_ptr() }
    }
}

impl DerefMut for AvlTreeNodeRef {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { self.ptr.as_mut() }
    }
}

/// Ideally this would be an "unsafe drop", but Rust does not support that today.
/// If the owning `AvlTree` is still in use, an `AvlTreeNode` must be removed from it
/// before being dropped. This impl itself performs no work.
impl Drop for AvlTreeNode {
    fn drop(&mut self) {
        // Intentionally empty: unlinking is the caller's responsibility.
    }
}

impl<T> AvlTree<T> {
    /// Creates an empty tree.
    ///
    /// `f` maps a pointer to a `T` to a pointer to its embedded [`AvlTreeNode`] member.
    /// The value returned by `crate::node_offset(f)` is stored and used as the byte offset
    /// between a `T` and its node (see `node_from` / `obj_from`). The `avltree!` macro
    /// builds `f` from a field path.
    pub fn new<F>(f: F) -> Self
    where
        F: FnOnce(*const T) -> *const AvlTreeNode,
    {
        Self {
            root: None,
            offset: crate::node_offset(f),
            mark: PhantomData,
        }
    }

    /// Returns `true` if the tree has no elements.
    pub fn empty(&self) -> bool {
        self.root.is_none()
    }

    /// Returns the smallest element (the leftmost node), or `None` if the tree is empty.
    ///
    /// NOTE: the returned lifetime `'a` is not tied to `&self`; the caller must not use
    /// the reference after the element has been removed or dropped.
    pub fn first<'a>(&self) -> Option<&'a T> {
        self.root.map(|root| self.obj_from(root.first()))
    }

    /// Returns the largest element (the rightmost node), or `None` if the tree is empty.
    ///
    /// NOTE: as with [`first`](Self::first), `'a` is not tied to `&self`.
    pub fn last<'a>(&self) -> Option<&'a T> {
        self.root.map(|root| self.obj_from(root.last()))
    }

    /// Inserts `item`, ordered by `T`'s own `Ord` implementation.
    ///
    /// Equivalent to `insert_borrow::<T>(item, overwrite)`; see
    /// [`insert_borrow`](Self::insert_borrow) for the meaning of `overwrite` and of the
    /// return value.
    ///
    /// # Safety
    /// A node inserted into the tree must be removed from it before it is dropped, or must
    /// outlive the tree. It must also not be moved (no ownership transfer) before it has
    /// been removed.
    pub unsafe fn insert<'a>(&mut self, item: &'a T, overwrite: bool) -> Option<Link<'a>>
    where
        T: Ord,
    {
        self.insert_borrow::<T>(item, overwrite)
    }

    /// Inserts `item`, ordered by the key obtained through `Borrow<Q>`.
    ///
    /// If an element with an equal key is already present:
    /// - with `overwrite == false` the tree is unchanged and `None` is returned;
    /// - with `overwrite == true` the existing element is unlinked and `item` takes its
    ///   place.
    ///
    /// Otherwise `item` is linked in and the tree is rebalanced. `Some(Link)` is returned
    /// whenever `item` ended up in the tree.
    ///
    /// # Safety
    /// A node inserted into the tree must be removed from it before it is dropped, or must
    /// outlive the tree. It must also not be moved (no ownership transfer) before it has
    /// been removed.
    pub unsafe fn insert_borrow<'a, Q>(&mut self, item: &'a T, overwrite: bool) -> Option<Link<'a>>
    where
        T: Borrow<Q>,
        Q: Ord,
    {
        let node = self.node_from(item);
        debug_assert!(!node.linked());

        let key: &Q = item.borrow();
        if let Some(root) = self.root {
            // `search` stops at an equal node, or at the node under which `key` belongs.
            match root.search(|node| key.cmp(self.obj_from(node).borrow())) {
                (parent, Ordering::Less) => {
                    parent.link_left(Some(node));
                    self.insert_rotation(node);
                }
                (parent, Ordering::Greater) => {
                    parent.link_right(Some(node));
                    self.insert_rotation(node);
                }
                (found, Ordering::Equal) => {
                    if overwrite {
                        self.replace_with(found, node);
                    } else {
                        return None;
                    }
                }
            }
        } else {
            self.set_root(Some(node));
        }
        Some(Link::new())
    }

    /// Removes the element whose key equals `key` and returns it, or `None` if there is
    /// no such element.
    pub fn remove<'a, Q>(&mut self, key: &Q) -> Option<&'a T>
    where
        Q: Ord,
        T: Borrow<Q>,
    {
        let ret = self.find(key);
        if let Some(found) = ret {
            let found = self.node_from(found);
            self.del(found);
        }
        ret
    }

    /// Calls `f` on every element in pre-order (node, then left subtree, then right subtree).
    pub fn walk<F>(&self, mut f: F)
    where
        F: FnMut(&T),
    {
        if let Some(root) = self.root {
            root.walk(|node, _| f(self.obj_from(node)));
        }
    }

    /// Checks the AVL balance rule: for every node, the heights of the left and right
    /// subtrees differ by at most 1.
    ///
    /// # Panics
    /// Panics (including `msg` in the message) if some node violates the rule. Subtree
    /// heights are recomputed for every node, so this is intended for debugging/tests only.
    pub fn assert(&self, msg: &str) {
        if let Some(root) = self.root {
            root.walk(|node, _| {
                let rh = Self::tree_height(node.get_right());
                let lh = Self::tree_height(node.get_left());
                assert!(
                    lh == rh || lh + 1 == rh || rh + 1 == lh,
                    "rh {rh}, lh {lh} msg {msg}"
                );
            });
        }
    }

    /// Height of the subtree rooted at `node` (0 for `None`).
    fn tree_height(node: Option<AvlTreeNodeRef>) -> usize {
        if let Some(node) = node {
            return node.get_absolute_height();
        }
        0
    }

    /// Returns the element whose key equals `key`, or `None`.
    pub fn find<'a, Q>(&self, key: &Q) -> Option<&'a T>
    where
        Q: Ord,
        T: Borrow<Q>,
    {
        match self
            .root?
            .search(|node| key.cmp(self.obj_from(node).borrow()))
        {
            (found, Ordering::Equal) => Some(self.obj_from(found)),
            _ => None,
        }
    }

    /// Iterates over all elements in ascending key order.
    pub fn iter(&self) -> Iter<'_, T> {
        Iter::new(self, self.first_node(), None)
    }

    /// Iterates mutably over all elements in ascending key order.
    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
        IterMut::new(self, self.first_node(), None)
    }

    /// Iterates over the elements whose keys fall within `range`, in ascending order.
    ///
    /// An empty interval (start after end, or equal keys with an excluded bound) yields
    /// nothing.
    pub fn range<Q, R>(&self, range: R) -> Iter<'_, T>
    where
        Q: Ord,
        T: Borrow<Q>,
        R: RangeBounds<Q>,
    {
        let (start, end) = self.range_nodes::<Q, R>(&range);
        Iter::new(self, start, end)
    }

    /// Mutable counterpart of [`range`](Self::range).
    pub fn range_mut<Q, R>(&mut self, range: R) -> IterMut<'_, T>
    where
        Q: Ord,
        T: Borrow<Q>,
        R: RangeBounds<Q>,
    {
        let (start, end) = self.range_nodes::<Q, R>(&range);
        IterMut::new(self, start, end)
    }

    /// Resolves `range` to the `[start, end)` cursor pair used by the iterators.
    ///
    /// Returns `(None, None)` (an empty iterator) when the bounds describe an empty
    /// interval. Without this check, `start` could lie after `end`. The forward walk would
    /// then never meet `end` and would run to the end of the tree, and `next_back` would
    /// walk backwards from `end`.
    fn range_nodes<Q, R>(&self, range: &R) -> (Option<AvlTreeNodeRef>, Option<AvlTreeNodeRef>)
    where
        Q: Ord,
        T: Borrow<Q>,
        R: RangeBounds<Q>,
    {
        let is_empty = match (range.start_bound(), range.end_bound()) {
            (Included(start), Included(end)) => start > end,
            (Included(start), Excluded(end))
            | (Excluded(start), Included(end))
            | (Excluded(start), Excluded(end)) => start >= end,
            _ => false,
        };
        if is_empty {
            return (None, None);
        }
        (
            self.find_start(range.start_bound()),
            self.find_end(range.end_bound()),
        )
    }

    /// Leftmost node, or `None` if the tree is empty.
    fn first_node(&self) -> Option<AvlTreeNodeRef> {
        Some(self.root?.first())
    }

    /// First node inside a range with the given start bound: the first node `>= key`
    /// (`Included`), `> key` (`Excluded`), or the leftmost node (`Unbounded`).
    /// `None` if there is no such node.
    fn find_start<Q>(&self, key: Bound<&Q>) -> Option<AvlTreeNodeRef>
    where
        Q: Ord,
        T: Borrow<Q>,
    {
        let root = self.root?;
        match key {
            Included(key) => {
                let (node, ordering) = root.search(|node| key.cmp(self.obj_from(node).borrow()));
                if ordering == Ordering::Greater {
                    node.next()
                } else {
                    Some(node)
                }
            }
            Excluded(key) => {
                let (node, ordering) = root.search(|node| key.cmp(self.obj_from(node).borrow()));
                if ordering != Ordering::Less {
                    node.next()
                } else {
                    Some(node)
                }
            }
            Unbounded => Some(root.first()),
        }
    }

    /// Exclusive end cursor for a range with the given end bound: the first node `> key`
    /// (`Included`) or `>= key` (`Excluded`). `None` means "past the last node".
    fn find_end<Q>(&self, key: Bound<&Q>) -> Option<AvlTreeNodeRef>
    where
        Q: Ord,
        T: Borrow<Q>,
    {
        let root = self.root?;
        match key {
            Included(key) => {
                let (node, ordering) = root.search(|node| key.cmp(self.obj_from(node).borrow()));
                if ordering != Ordering::Less {
                    node.next()
                } else {
                    Some(node)
                }
            }
            Excluded(key) => {
                let (node, ordering) = root.search(|node| key.cmp(self.obj_from(node).borrow()));
                if ordering == Ordering::Greater {
                    node.next()
                } else {
                    Some(node)
                }
            }
            Unbounded => None,
        }
    }

    /// Handle to the `AvlTreeNode` embedded in `obj`.
    fn node_from(&self, obj: &T) -> AvlTreeNodeRef {
        // Pointer (not integer) arithmetic keeps the provenance of `obj`.
        let node = (obj as *const T)
            .cast::<u8>()
            .wrapping_add(self.offset)
            .cast::<AvlTreeNode>();
        // SAFETY: `offset` locates the `AvlTreeNode` member inside `T`, so `node` points
        // into `*obj`.
        unsafe { &*node }.into()
    }

    /// Element that contains `node`. The lifetime `'a` is chosen by the caller.
    fn obj_from<'a>(&self, node: AvlTreeNodeRef) -> &'a T {
        let obj = node.as_ptr().cast::<u8>().wrapping_sub(self.offset).cast::<T>();
        // SAFETY: every linked node is embedded `offset` bytes into a live `T`
        // (insertion contract).
        unsafe { &*obj }
    }

    /// Mutable counterpart of `obj_from`.
    fn obj_from_mut<'a>(&self, node: AvlTreeNodeRef) -> &'a mut T {
        let obj = node.as_ptr().cast::<u8>().wrapping_sub(self.offset).cast::<T>();
        // SAFETY: see `obj_from`.
        unsafe { &mut *obj }
    }

    /// Puts unlinked `new` in `old`'s position (same parent, children and balance factor)
    /// and resets `old` to the unlinked state.
    fn replace_with(&mut self, old: AvlTreeNodeRef, new: AvlTreeNodeRef) {
        debug_assert!(!new.linked());
        self.link_child(old.get_parent(), Some(new), old.get_pos());
        new.set_height(old.get_height());
        new.link_left(old.get_left());
        new.link_right(old.get_right());
        old.init_node();
    }

    /// Sets the root and clears its parent pointer.
    fn set_root(&mut self, node: Option<AvlTreeNodeRef>) {
        self.root = node;
        if let Some(node) = node {
            node.set_null_parent();
        }
    }

    /// Attaches `node` as the `pos` child of `parent`, or as the root when `parent` is
    /// `None`.
    fn link_child(
        &mut self,
        parent: Option<AvlTreeNodeRef>,
        node: Option<AvlTreeNodeRef>,
        pos: Pos,
    ) {
        if let Some(parent) = parent {
            parent.link_child(node, pos);
        } else {
            self.set_root(node);
        }
    }

    /// Rebalances after the leaf `node` was linked in. Walks up, updating balance factors,
    /// until the height increase is absorbed or fixed by a single/double rotation.
    fn insert_rotation(&mut self, mut node: AvlTreeNodeRef) {
        node.set_height(Height::Eq);
        while let Some(parent) = node.get_parent() {
            match parent.get_height() {
                Height::Eq => {
                    // `parent` grew by one level on `node`'s side; keep propagating.
                    match node.get_pos() {
                        Pos::Left => parent.set_height(Height::Lh),
                        _ => parent.set_height(Height::Rh),
                    }
                    node = parent;
                    continue;
                }
                Height::Lh => {
                    if matches!(node.get_pos(), Pos::Right) {
                        // The shorter side grew: balanced, height unchanged.
                        parent.set_height(Height::Eq);
                    } else if matches!(node.get_height(), Height::Rh) {
                        // Left-right case: double rotation around `node.right`.
                        let right = node.get_right().unwrap();
                        self.rr_rotation(right, node, parent);
                        match right.get_height() {
                            Height::Eq => {
                                parent.set_height(Height::Eq);
                                node.set_height(Height::Eq);
                            }
                            Height::Lh => {
                                parent.set_height(Height::Rh);
                                node.set_height(Height::Eq);
                            }
                            Height::Rh => {
                                parent.set_height(Height::Eq);
                                node.set_height(Height::Lh);
                            }
                        }
                        right.set_height(Height::Eq);
                    } else {
                        // Left-left case: single rotation.
                        assert!(!matches!(node.get_height(), Height::Eq));
                        self.lr_rotation(node, parent);
                        node.set_height(Height::Eq);
                        parent.set_height(Height::Eq);
                    }
                }
                Height::Rh => {
                    if matches!(node.get_pos(), Pos::Left) {
                        // The shorter side grew: balanced, height unchanged.
                        parent.set_height(Height::Eq);
                    } else if matches!(node.get_height(), Height::Lh) {
                        // Right-left case: double rotation around `node.left`.
                        let left = node.get_left().unwrap();
                        self.ll_rotation(left, node, parent);
                        match left.get_height() {
                            Height::Eq => {
                                parent.set_height(Height::Eq);
                                node.set_height(Height::Eq);
                            }
                            Height::Lh => {
                                parent.set_height(Height::Eq);
                                node.set_height(Height::Rh);
                            }
                            Height::Rh => {
                                parent.set_height(Height::Lh);
                                node.set_height(Height::Eq);
                            }
                        }
                        left.set_height(Height::Eq);
                    } else {
                        // Right-right case: single rotation.
                        assert!(!matches!(node.get_height(), Height::Eq));
                        self.rl_rotation(node, parent);
                        node.set_height(Height::Eq);
                        parent.set_height(Height::Eq);
                    }
                }
            }
            // Every non-`Eq` case above absorbs the height increase.
            return;
        }
    }

    /// Rebalances after a removal. `pnode` is the node whose `pos` subtree became one level
    /// shorter. Walks up, updating balance factors and rotating, until the height change is
    /// absorbed.
    fn del_rotation(&mut self, mut pnode: Option<AvlTreeNodeRef>, mut pos: Pos) {
        while let Some(parent) = pnode {
            // Captured before any rotation; the rotated subtree root is re-parented to it.
            pnode = parent.get_parent();
            match parent.get_height() {
                Height::Eq => {
                    // The other side is now one level taller; total height unchanged.
                    if matches!(pos, Pos::Left) {
                        parent.set_height(Height::Rh);
                    } else {
                        parent.set_height(Height::Lh);
                    }
                    return;
                }
                Height::Lh => {
                    if matches!(pos, Pos::Left) {
                        // The taller side shrank: balanced, but one level shorter.
                        parent.set_height(Height::Eq);
                        pos = parent.get_pos();
                        continue;
                    } else {
                        // The shorter side shrank: left is now two levels taller.
                        let node = parent.get_left().unwrap();
                        match node.get_height() {
                            Height::Rh => {
                                let right = node.get_right().unwrap();
                                self.rr_rotation(right, node, parent);
                                match right.get_height() {
                                    Height::Eq => {
                                        node.set_height(Height::Eq);
                                        parent.set_height(Height::Eq);
                                    }
                                    Height::Lh => {
                                        node.set_height(Height::Eq);
                                        parent.set_height(Height::Rh);
                                    }
                                    Height::Rh => {
                                        node.set_height(Height::Lh);
                                        parent.set_height(Height::Eq);
                                    }
                                }
                                right.set_height(Height::Eq);
                                pos = right.get_pos();
                                continue;
                            }
                            Height::Lh => {
                                self.lr_rotation(node, parent);
                                parent.set_height(Height::Eq);
                                node.set_height(Height::Eq);
                                pos = node.get_pos();
                                continue;
                            }
                            Height::Eq => {
                                // The rotation keeps the subtree height: stop here.
                                self.lr_rotation(node, parent);
                                node.set_height(Height::Rh);
                                parent.set_height(Height::Lh);
                                return;
                            }
                        }
                    }
                }
                Height::Rh => {
                    if matches!(pos, Pos::Right) {
                        // The taller side shrank: balanced, but one level shorter.
                        parent.set_height(Height::Eq);
                        pos = parent.get_pos();
                        continue;
                    } else {
                        // The shorter side shrank: right is now two levels taller.
                        let node = parent.get_right().unwrap();
                        match node.get_height() {
                            Height::Lh => {
                                let left = node.get_left().unwrap();
                                self.ll_rotation(left, node, parent);
                                match left.get_height() {
                                    Height::Eq => {
                                        node.set_height(Height::Eq);
                                        parent.set_height(Height::Eq);
                                    }
                                    Height::Lh => {
                                        node.set_height(Height::Rh);
                                        parent.set_height(Height::Eq);
                                    }
                                    Height::Rh => {
                                        node.set_height(Height::Eq);
                                        parent.set_height(Height::Lh);
                                    }
                                }
                                left.set_height(Height::Eq);
                                pos = left.get_pos();
                                continue;
                            }
                            Height::Rh => {
                                self.rl_rotation(node, parent);
                                node.set_height(Height::Eq);
                                parent.set_height(Height::Eq);
                                pos = node.get_pos();
                                continue;
                            }
                            Height::Eq => {
                                // The rotation keeps the subtree height: stop here.
                                self.rl_rotation(node, parent);
                                node.set_height(Height::Lh);
                                parent.set_height(Height::Rh);
                                return;
                            }
                        }
                    }
                }
            }
        }
    }

    /// Single rotation with `node == parent.left`. `node` takes `parent`'s place,
    /// `parent` becomes `node`'s right child, and `node`'s old right subtree becomes
    /// `parent`'s left subtree. Balance factors are left to the caller.
    #[inline]
    fn lr_rotation(&mut self, node: AvlTreeNodeRef, parent: AvlTreeNodeRef) {
        debug_assert!(ptr::eq(node.as_ptr(), parent.get_left().unwrap().as_ptr()));
        self.link_child(parent.get_parent(), Some(node), parent.get_pos());
        parent.link_left(node.get_right());
        node.link_right(Some(parent));
    }

    /// Double rotation with `node == parent.right` and `parent == pparent.left`. `node`
    /// takes `pparent`'s place with `parent` on its left and `pparent` on its right.
    /// `node`'s old left/right subtrees become `parent.right` / `pparent.left`.
    #[inline]
    fn rr_rotation(
        &mut self,
        node: AvlTreeNodeRef,
        parent: AvlTreeNodeRef,
        pparent: AvlTreeNodeRef,
    ) {
        debug_assert!(ptr::eq(node.as_ptr(), parent.get_right().unwrap().as_ptr()));
        debug_assert!(ptr::eq(
            parent.as_ptr(),
            pparent.get_left().unwrap().as_ptr()
        ));
        self.link_child(pparent.get_parent(), Some(node), pparent.get_pos());
        parent.link_right(node.get_left());
        pparent.link_left(node.get_right());
        node.link_left(Some(parent));
        node.link_right(Some(pparent));
    }

    /// Single rotation with `node == parent.right`. `node` takes `parent`'s place,
    /// `parent` becomes `node`'s left child, and `node`'s old left subtree becomes
    /// `parent`'s right subtree.
    #[inline]
    fn rl_rotation(&mut self, node: AvlTreeNodeRef, parent: AvlTreeNodeRef) {
        debug_assert!(ptr::eq(node.as_ptr(), parent.get_right().unwrap().as_ptr()));
        self.link_child(parent.get_parent(), Some(node), parent.get_pos());
        parent.link_right(node.get_left());
        node.link_left(Some(parent));
    }

    /// Double rotation with `node == parent.left` and `parent == pparent.right`. `node`
    /// takes `pparent`'s place with `pparent` on its left and `parent` on its right.
    /// `node`'s old right/left subtrees become `parent.left` / `pparent.right`.
    #[inline]
    fn ll_rotation(
        &mut self,
        node: AvlTreeNodeRef,
        parent: AvlTreeNodeRef,
        pparent: AvlTreeNodeRef,
    ) {
        debug_assert!(ptr::eq(node.as_ptr(), parent.get_left().unwrap().as_ptr()));
        debug_assert!(ptr::eq(
            parent.as_ptr(),
            pparent.get_right().unwrap().as_ptr()
        ));
        self.link_child(pparent.get_parent(), Some(node), pparent.get_pos());
        parent.link_left(node.get_right());
        pparent.link_right(node.get_left());
        node.link_right(Some(parent));
        node.link_left(Some(pparent));
    }

    /// Unlinks `node`, rebalances, and resets `node` to the unlinked state.
    fn del(&mut self, node: AvlTreeNodeRef) {
        let parent = node.get_parent();
        let (parent, pos) = self.prepare_del(node, parent);
        self.del_rotation(parent, pos);
        node.init_node();
    }

    /// Detaches `node` from the structure. Returns `(p, pos)`, where the `pos` subtree of
    /// `p` became one level shorter (the input for `del_rotation`).
    ///
    /// A node with two children is replaced by its in-order successor when left-heavy,
    /// otherwise by its predecessor.
    /// NOTE(review): taking the predecessor from the taller left subtree when left-heavy
    /// would avoid a rotation. The current choice is kept to preserve tree shape.
    fn prepare_del(
        &mut self,
        node: AvlTreeNodeRef,
        parent: Option<AvlTreeNodeRef>,
    ) -> (Option<AvlTreeNodeRef>, Pos) {
        if node.get_left().is_none() {
            self.link_child(parent, node.get_right(), node.get_pos());
            (parent, node.get_pos())
        } else if node.get_right().is_none() {
            self.link_child(parent, node.get_left(), node.get_pos());
            (parent, node.get_pos())
        } else if matches!(node.get_height(), Height::Lh) {
            self.swap_next_del(node, node.next().unwrap(), parent)
        } else {
            self.swap_prev_del(node, node.prev().unwrap(), parent)
        }
    }

    /// Replaces the two-child `node` with its successor `item` (the leftmost node of
    /// `node.right`). `item` inherits `node`'s balance factor. Returns the node whose
    /// subtree shrank, and on which side.
    #[inline]
    fn swap_next_del(
        &mut self,
        node: AvlTreeNodeRef,
        item: AvlTreeNodeRef,
        parent: Option<AvlTreeNodeRef>,
    ) -> (Option<AvlTreeNodeRef>, Pos) {
        let item_parent = item.get_parent().unwrap();
        item.set_height(node.get_height());
        if !ptr::eq(item_parent.as_ptr(), node.as_ptr()) {
            // `item` is deeper: its right subtree takes its place under `item_parent`.
            item_parent.link_left(item.get_right());
            self.link_child(parent, Some(item), node.get_pos());
            item.link_left(node.get_left());
            item.link_right(node.get_right());
            (Some(item_parent), Pos::Left)
        } else {
            // `item` is `node.right` and keeps its own right subtree.
            self.link_child(parent, Some(item), node.get_pos());
            item.link_left(node.get_left());
            (Some(item), Pos::Right)
        }
    }

    /// Replaces the two-child `node` with its predecessor `item` (the rightmost node of
    /// `node.left`). `item` inherits `node`'s balance factor. Returns the node whose
    /// subtree shrank, and on which side.
    #[inline]
    fn swap_prev_del(
        &mut self,
        node: AvlTreeNodeRef,
        item: AvlTreeNodeRef,
        parent: Option<AvlTreeNodeRef>,
    ) -> (Option<AvlTreeNodeRef>, Pos) {
        let item_parent = item.get_parent().unwrap();
        item.set_height(node.get_height());
        if !ptr::eq(item_parent.as_ptr(), node.as_ptr()) {
            // `item` is deeper: its left subtree takes its place under `item_parent`.
            item_parent.link_right(item.get_left());
            self.link_child(parent, Some(item), node.get_pos());
            item.link_left(node.get_left());
            item.link_right(node.get_right());
            (Some(item_parent), Pos::Right)
        } else {
            // `item` is `node.left` and keeps its own left subtree.
            self.link_child(parent, Some(item), node.get_pos());
            item.link_right(node.get_right());
            (Some(item), Pos::Left)
        }
    }
}

impl Default for AvlTreeNode {
    fn default() -> Self {
        Self::new()
    }
}

impl AvlTreeNode {
    /// Mask of the tag bits kept in the low bits of `Inner::left` (`Pos`) and
    /// `Inner::right` (`Height`).
    const TAG_MASK: usize = 0x3;

    /// Creates an unlinked node.
    pub const fn new() -> Self {
        Self(UnsafeCell::new(Inner {
            parent: None,
            left: 0,
            right: 0,
            pin: PhantomPinned,
            mark: PhantomData,
        }))
    }

    /// In-order successor, or `None` for the last node.
    ///
    /// The root's parent is `None` (`set_null_parent` clears it), so any stale position
    /// bits on the root cannot produce a wrong result here.
    fn next(&self) -> Option<AvlTreeNodeRef> {
        if let Some(right) = self.get_right() {
            // Leftmost node of the right subtree.
            Some(right.first())
        } else if matches!(self.get_pos(), Pos::Left) {
            // A left child without a right subtree is followed by its parent.
            self.get_parent()
        } else {
            // Climb while ancestors are right children. The successor is the parent of the
            // first ancestor that is a left child.
            let mut node = AvlTreeNodeRef::from(self);
            while let Some(parent) = node.get_parent() {
                if matches!(parent.get_pos(), Pos::Left) {
                    return parent.get_parent();
                }
                node = parent;
            }
            None
        }
    }

    /// In-order predecessor, or `None` for the first node (mirror of `next`).
    fn prev(&self) -> Option<AvlTreeNodeRef> {
        if let Some(left) = self.get_left() {
            // Rightmost node of the left subtree.
            Some(left.last())
        } else if matches!(self.get_pos(), Pos::Right) {
            self.get_parent()
        } else {
            let mut node = AvlTreeNodeRef::from(self);
            while let Some(parent) = node.get_parent() {
                if matches!(parent.get_pos(), Pos::Right) {
                    return parent.get_parent();
                }
                node = parent;
            }
            None
        }
    }

    /// Pre-order traversal of the subtree rooted at `self`, without recursion or extra
    /// storage. `f` receives each node and its depth (`self` has depth 1).
    fn walk<F>(&self, mut f: F)
    where
        F: FnMut(AvlTreeNodeRef, usize),
    {
        let mut height = 0;
        let mut node = AvlTreeNodeRef::from(self);
        'start: loop {
            height += 1;
            f(node, height);
            if let Some(left) = node.get_left() {
                node = left;
            } else if let Some(right) = node.get_right() {
                node = right;
            } else {
                // Leaf: climb until an unvisited right sibling is found, stopping at `self`.
                while let Some(parent) = node.get_parent() {
                    if ptr::eq(node.as_ptr(), self) {
                        return;
                    }
                    height -= 1;
                    if matches!(node.get_pos(), Pos::Left) {
                        if let Some(sibling) = parent.get_right() {
                            node = sibling;
                            continue 'start;
                        }
                    }
                    node = parent;
                }
                return;
            }
        }
    }

    /// Links `child` on the `pos` side of `self`.
    fn link_child(&self, child: Option<AvlTreeNodeRef>, pos: Pos) {
        match pos {
            Pos::Left => self.link_left(child),
            _ => self.link_right(child),
        }
    }

    /// Binary search from `self`. `comp(node)` compares the search key against `node`.
    /// Returns the node where the search stopped, together with that comparison:
    /// `Equal` for a match, otherwise `Less`/`Greater` at a node with no child on that side.
    fn search<F>(&self, comp: F) -> (AvlTreeNodeRef, Ordering)
    where
        F: Fn(AvlTreeNodeRef) -> Ordering,
    {
        let mut cur = AvlTreeNodeRef::from(self);
        loop {
            match comp(cur) {
                Ordering::Less => {
                    if let Some(left) = cur.get_left() {
                        cur = left;
                    } else {
                        return (cur, Ordering::Less);
                    }
                }
                Ordering::Greater => {
                    if let Some(right) = cur.get_right() {
                        cur = right;
                    } else {
                        return (cur, Ordering::Greater);
                    }
                }
                Ordering::Equal => {
                    return (cur, Ordering::Equal);
                }
            }
        }
    }

    /// Leftmost node of the subtree rooted at `self`.
    #[inline]
    fn first(&self) -> AvlTreeNodeRef {
        if let Some(mut left) = self.get_left() {
            while let Some(item) = left.get_left() {
                left = item;
            }
            left
        } else {
            self.into()
        }
    }

    /// Rightmost node of the subtree rooted at `self`.
    #[inline]
    fn last(&self) -> AvlTreeNodeRef {
        if let Some(mut right) = self.get_right() {
            while let Some(item) = right.get_right() {
                right = item;
            }
            right
        } else {
            self.into()
        }
    }

    fn get_parent(&self) -> Option<AvlTreeNodeRef> {
        self.inner().parent
    }

    /// Decodes a tagged child word (`Inner::left` / `Inner::right`) into a node handle.
    fn child_from_word(word: usize) -> Option<AvlTreeNodeRef> {
        let addr = word & !Self::TAG_MASK;
        if addr == 0 {
            None
        } else {
            // SAFETY: non-zero addresses are only stored by `link_left`/`link_right`, from
            // nodes that the insertion contract keeps alive while linked.
            Some(unsafe { &*(addr as *const Self) }.into())
        }
    }

    fn get_left(&self) -> Option<AvlTreeNodeRef> {
        Self::child_from_word(self.inner().left)
    }

    fn get_right(&self) -> Option<AvlTreeNodeRef> {
        Self::child_from_word(self.inner().right)
    }

    /// Sets (or clears) the left child, preserving this node's `Pos` tag, and points the
    /// child back at `self`.
    fn link_left(&self, left: Option<AvlTreeNodeRef>) {
        let inner = unsafe { &mut *self.0.get() };
        if let Some(left) = left {
            let addr = left.as_ptr() as usize;
            // The tag bits can only be used if node addresses leave them clear.
            debug_assert!(addr & Self::TAG_MASK == 0);
            inner.left = addr | (inner.left & Self::TAG_MASK);
            left.set_parent(self.into(), Pos::Left);
        } else {
            inner.left &= Self::TAG_MASK;
        }
    }

    /// Sets (or clears) the right child, preserving this node's `Height` tag, and points
    /// the child back at `self`.
    fn link_right(&self, right: Option<AvlTreeNodeRef>) {
        let inner = unsafe { &mut *self.0.get() };
        if let Some(right) = right {
            let addr = right.as_ptr() as usize;
            debug_assert!(addr & Self::TAG_MASK == 0);
            inner.right = addr | (inner.right & Self::TAG_MASK);
            right.set_parent(self.into(), Pos::Right);
        } else {
            inner.right &= Self::TAG_MASK;
        }
    }

    /// Records `parent` and this node's side under it (in the tag bits of `left`).
    fn set_parent(&self, parent: AvlTreeNodeRef, pos: Pos) {
        let inner = unsafe { &mut *self.0.get() };
        inner.parent = Some(parent);
        inner.left = (inner.left & !Self::TAG_MASK) | (pos as usize);
    }

    fn get_pos(&self) -> Pos {
        match self.inner().left & Self::TAG_MASK {
            0x00 => Pos::Left,
            _ => Pos::Right,
        }
    }

    fn get_height(&self) -> Height {
        match self.inner().right & Self::TAG_MASK {
            0x00 => Height::Eq,
            0x01 => Height::Lh,
            _ => Height::Rh,
        }
    }

    /// Height of the subtree rooted at `self` (1 for a single node). Computed by visiting
    /// every node, so it is O(n); `AvlTree::assert` uses it through `tree_height`.
    fn get_absolute_height(&self) -> usize {
        let mut max_depth = 0;
        self.walk(|node, depth| {
            if node.get_left().is_none() && node.get_right().is_none() {
                max_depth = max_depth.max(depth);
            }
        });
        max_depth
    }

    /// Stores the balance factor in the tag bits of `right`.
    fn set_height(&self, height: Height) {
        let inner = unsafe { &mut *self.0.get() };
        inner.right = (inner.right & !Self::TAG_MASK) | height as usize;
    }

    /// Clears the parent pointer only; tag bits are left untouched.
    fn set_null_parent(&self) {
        let inner = unsafe { &mut *self.0.get() };
        inner.parent = None;
    }

    /// Resets the node to the unlinked state.
    fn init_node(&self) {
        let inner = unsafe { &mut *self.0.get() };
        inner.parent = None;
        inner.left = 0;
        inner.right = 0;
    }

    fn inner(&self) -> &Inner {
        unsafe { &*self.0.get() }
    }

    /// `true` if any link state is set.
    /// NOTE(review): a node that is the only element of a tree (root with no children,
    /// `Pos::Left`/`Height::Eq` tags) has all-zero state and reports `false`.
    fn linked(&self) -> bool {
        let inner = self.inner();
        inner.parent.is_some() || inner.left != 0 || inner.right != 0
    }
}

/// Borrowing iterator over the elements of an [`AvlTree`] in ascending key order.
///
/// Yields nodes from `start` (inclusive) up to `end` (exclusive); `end == None` means
/// "past the last node".
pub struct Iter<'a, T> {
    start: Option<AvlTreeNodeRef>,
    end: Option<AvlTreeNodeRef>,
    tree: &'a AvlTree<T>,
}

impl<'a, T> Iter<'a, T> {
    fn new(
        tree: &'a AvlTree<T>,
        start: Option<AvlTreeNodeRef>,
        end: Option<AvlTreeNodeRef>,
    ) -> Self {
        Iter { start, end, tree }
    }
}

/// Returns `true` when two iterator positions refer to the same node, or when
/// both are `None` (the past-the-end position).
fn pos_equal(lhs: Option<AvlTreeNodeRef>, rhs: Option<AvlTreeNodeRef>) -> bool {
    match (lhs, rhs) {
        (Some(a), Some(b)) => a.as_ptr() == b.as_ptr(),
        (None, None) => true,
        _ => false,
    }
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    /// Yields the object at the front position and advances the front to the
    /// following node. Returns `None` once the front has met the back.
    fn next(&mut self) -> Option<Self::Item> {
        let front = self.start?;
        if pos_equal(self.start, self.end) {
            return None;
        }
        self.start = front.next();
        Some(self.tree.obj_from(front))
    }
}

impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    /// Steps the back position one node towards the front and yields that
    /// node's object.
    ///
    /// An unset back position (`None`) starts from the tree's last node.
    /// Returns `None` in three cases: the front and back positions coincide,
    /// the back node has no predecessor, or the tree is empty.
    fn next_back(&mut self) -> Option<Self::Item> {
        if pos_equal(self.start, self.end) {
            return None;
        }
        let back = match self.end {
            Some(end) => end.prev()?,
            None => self.tree.root?.last(),
        };
        self.end = Some(back);
        Some(self.tree.obj_from(back))
    }
}

impl<'a, T> IntoIterator for &'a AvlTree<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

pub struct IterMut<'a, T> {
    tree: &'a AvlTree<T>,
    start: Option<AvlTreeNodeRef>,
    end: Option<AvlTreeNodeRef>,
}

impl<'a, T> IterMut<'a, T> {
    fn new(
        tree: &'a AvlTree<T>,
        start: Option<AvlTreeNodeRef>,
        end: Option<AvlTreeNodeRef>,
    ) -> Self {
        Self { tree, start, end }
    }
}

impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    /// Yields the object at the front position mutably and advances the front
    /// to the following node. Returns `None` once the front has met the back.
    fn next(&mut self) -> Option<Self::Item> {
        let front = self.start?;
        if pos_equal(self.start, self.end) {
            return None;
        }
        self.start = front.next();
        Some(self.tree.obj_from_mut(front))
    }
}

impl<'a, T> DoubleEndedIterator for IterMut<'a, T> {
    /// Steps the back position one node towards the front and yields that
    /// node's object mutably.
    ///
    /// An unset back position (`None`) starts from the tree's last node.
    /// Returns `None` in three cases: the front and back positions coincide,
    /// the back node has no predecessor, or the tree is empty.
    fn next_back(&mut self) -> Option<Self::Item> {
        if pos_equal(self.start, self.end) {
            return None;
        }
        let back = match self.end {
            Some(end) => end.prev()?,
            None => self.tree.root?.last(),
        };
        self.end = Some(back);
        Some(self.tree.obj_from_mut(back))
    }
}

impl<'a, T> IntoIterator for &'a mut AvlTree<T> {
    type Item = &'a mut T;
    type IntoIter = IterMut<'a, T>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter_mut()
    }
}

#[cfg(test)]
mod test {
    extern crate std;
    use crate::AvlTreeNode;
    use rand;
    use std::borrow::Borrow;
    use std::cmp::{Ordering, PartialEq, PartialOrd};
    use std::format;
    use std::mem::MaybeUninit;
    use std::println;
    use std::ptr;

    /// Test payload: an `i32` key plus the intrusive node that links it into
    /// the tree.
    ///
    /// The tests build `[Foo; N]` arrays with `MaybeUninit::zeroed()`.
    /// NOTE(review): that relies on an all-zero `AvlTreeNode` being a valid,
    /// unlinked node. This is the same state `init_node` writes
    /// (`parent = None`, `left = right = 0`). It holds if
    /// `Option<AvlTreeNodeRef>` uses a null niche; confirm if the node layout
    /// changes.
    #[repr(C)]
    struct Foo {
        val: i32,
        node: AvlTreeNode,
    }

    // Lets the tree look objects up by their bare `i32` key
    // (e.g. `tree.find(&0)`, `insert_borrow::<i32>`).
    impl Borrow<i32> for Foo {
        fn borrow(&self) -> &i32 {
            &self.val
        }
    }

    impl Eq for Foo {}
    impl Ord for Foo {
        fn cmp(&self, other: &Self) -> Ordering {
            self.val.cmp(&other.val)
        }
    }

    impl PartialOrd for Foo {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            // Delegate to `Ord` so the two orderings can never disagree.
            Some(self.cmp(other))
        }
    }

    impl PartialEq for Foo {
        fn eq(&self, other: &Self) -> bool {
            self.val == other.val
        }
    }

    #[test]
    fn test_insert_0_99() {
        let mut tree = avltree!(Foo, node);

        let mut foos: [Foo; 100] = unsafe { MaybeUninit::zeroed().assume_init() };
        foos.iter_mut().fold(0, |n, foo| {
            foo.val = n;
            let link = unsafe { tree.insert_borrow::<i32>(foo, false) };
            assert!(link.is_some());
            tree.assert(format!("insert {}", foo.val).as_str());
            n + 1
        });

        foos.iter().for_each(|foo| {
            let found = tree.find(foo);
            assert!(found.is_some());
            assert!(ptr::eq(found.unwrap(), foo));
        });

        let first = tree.first().unwrap();
        let last = tree.last().unwrap();
        assert_eq!(first.val, 0);
        assert_eq!(last.val, 99);
    }

    #[test]
    fn test_insert_99_0() {
        let mut tree = avltree!(Foo, node);

        let mut foos: [Foo; 100] = unsafe { MaybeUninit::zeroed().assume_init() };
        foos.iter_mut().fold(0, |n, foo| {
            foo.val = 100 - n;
            let link = unsafe { tree.insert(foo, false) };
            assert!(link.is_some());
            tree.assert(format!("insert {}", foo.val).as_str());
            n + 1
        });

        foos.iter().for_each(|foo| {
            let found = tree.find(foo);
            assert!(found.is_some());
            assert!(ptr::eq(found.unwrap(), foo));
        });

        let first = tree.first().unwrap();
        let last = tree.last().unwrap();
        assert_eq!(first.val, 1);
        assert_eq!(last.val, 100);
    }

    /// Inserts 1000 objects with random keys, then removes every one of them,
    /// checking the tree invariants around each step. Despite its name, this
    /// test does not use the keys 0..99.
    #[test]
    fn test_del_0_99() {
        let mut tree = avltree!(Foo, node);

        let mut foos: [Foo; 1000] = unsafe { MaybeUninit::zeroed().assume_init() };

        foos.iter_mut().fold(0i32, |n, foo| {
            // The random part spans the whole `i32` range. A plain `+` overflows,
            // and panics in debug builds, whenever it lands within `n` (or, in
            // the retry loop, 3) of `i32::MAX`, so wrap explicitly instead.
            foo.val = foo.val.wrapping_add(n).wrapping_add(rand::random::<i32>());
            println!("insert foo.val = {}", foo.val);
            while unsafe { tree.insert(foo, false) }.is_none() {
                foo.val = foo.val.wrapping_add(3);
                println!("reset foo.val = {}", foo.val);
            }
            tree.assert(format!("insert {}", foo.val).as_str());
            n + 1
        });
        tree.assert("after insert");
        foos.iter().for_each(|foo| {
            let found = tree.find(foo);
            assert!(found.is_some());
            assert!(ptr::eq(found.unwrap(), foo));

            tree.assert(format!("before remove {}", foo.val).as_str());
            println!(
                "remove: {} height: {} pos: {}",
                foo.val,
                foo.node.get_absolute_height(),
                foo.node.get_pos() as usize
            );
            let ret = tree.remove(foo);
            assert!(ret.is_some());
            assert!(ptr::eq(ret.unwrap(), foo));
            tree.assert(format!("after remove {}", foo.val).as_str());
        });

        assert!(tree.first().is_none());
        assert!(tree.last().is_none());
        assert!(tree.empty());
    }

    #[test]
    fn test_iter() {
        let mut tree = avltree!(Foo, node);

        let mut foos: [Foo; 100] = unsafe { MaybeUninit::zeroed().assume_init() };
        for (n, foo) in foos.iter_mut().enumerate() {
            foo.val = n as i32;
            unsafe { tree.insert(foo, false) };
            tree.assert(format!("insert {}", foo.val).as_str());
        }

        let iter = tree.iter();
        let min = tree.first().unwrap().val - 1;
        assert_eq!(min, -1);

        assert_eq!(tree.iter().count(), 100);
        for (n, foo) in iter.enumerate() {
            assert_eq!(foo.val, n as i32);
        }

        for (n, foo) in tree.iter().rev().enumerate() {
            assert_eq!(99 - foo.val, n as i32);
        }

        assert_eq!(tree.iter().rev().count(), 100);
    }

    #[test]
    fn test_range() {
        let mut tree = avltree!(Foo, node);
        let mut foos: [Foo; 100] = unsafe { MaybeUninit::zeroed().assume_init() };
        for (n, foo) in foos.iter_mut().enumerate() {
            foo.val = n as i32;
            unsafe { tree.insert(foo, false) };
        }

        assert_eq!(tree.range(0..).count(), 100);
        assert_eq!(tree.range(-1..).rev().count(), 100);
        assert_eq!(tree.range(100..201).count(), 0);
        assert_eq!(tree.range(99..100).count(), 1);
        assert_eq!(tree.range(0..0).count(), 0);
        assert_eq!(tree.range(0..=0).count(), 1);
        assert_eq!(tree.range(0..1).count(), 1);

        assert_eq!(0, tree.range(100..101).count());
        assert_eq!(1, tree.range(99..101).count());
        assert_eq!(100, tree.range(0..101).count());
        assert_eq!(50, tree.range(1..51).count());

        let mut iter = tree.range(50..=50);
        assert_eq!(iter.next_back().unwrap().val, 50);
        assert!(iter.next_back().is_none());
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_iter_mut() {
        let mut tree = avltree!(Foo, node);

        let mut foos: [Foo; 100] = unsafe { MaybeUninit::zeroed().assume_init() };
        for (n, foo) in foos.iter_mut().enumerate() {
            foo.val = n as i32;
            unsafe { tree.insert(foo, false) };
            tree.assert(format!("insert {}", foo.val).as_str());
        }

        for foo in &mut tree {
            foo.val += 1;
        }

        assert!(tree.find(&0).is_none());
        assert!(tree.find(&1).is_some());
    }

    #[test]
    fn test_replace() {
        let mut tree = avltree!(Foo, node);
        let foo = Foo {
            val: 100,
            node: AvlTreeNode::new(),
        };
        unsafe { tree.insert(&foo, false) };

        let bar = Foo {
            val: 100,
            node: AvlTreeNode::new(),
        };
        assert!(unsafe { tree.insert(&bar, false) }.is_none());
        assert!(unsafe { tree.insert(&bar, true) }.is_some());

        let found = tree.find(&foo);
        assert!(found.is_some());
        assert!(ptr::eq(found.unwrap(), &bar));

        let baz = Foo {
            val: 101,
            node: AvlTreeNode::new(),
        };
        assert!(unsafe { tree.insert(&baz, false) }.is_some());
        assert!(unsafe { tree.insert(&foo, true) }.is_some());

        let found = tree.find(&bar);
        assert!(found.is_some());
        assert!(ptr::eq(found.unwrap(), &foo));
    }
}
